# Mycroft Server - Backend
# Copyright (C) 2019 Mycroft AI Inc
# SPDX-License-Identifier: AGPL-3.0-or-later
#
# This file is part of the Mycroft Server.
#
# The Mycroft Server is free software: you can redistribute it and/or
# modify it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Python code to support the device pairing feature."""
import json
import uuid

from unittest.mock import patch, MagicMock

from behave import given, then, when  # pylint: disable=no-name-in-module
from hamcrest import assert_that, equal_to, has_key, none, not_none

from selene.data.device import DeviceRepository
from selene.util.cache import DEVICE_PAIRING_CODE_KEY, DEVICE_PAIRING_TOKEN_KEY

# Durations in seconds; used below as cache expirations and expected values.
ONE_MINUTE = 60
ONE_DAY = 24 * 60 * ONE_MINUTE


@given("the user completes the pairing process on the web application")
def add_device(context):
    """Seed the cache with the pairing data the account API would have written.

    The pairing token and state are also kept on the context so that later
    steps can send them back in the activation request.
    """
    context.pairing_token = "pairing_token"
    context.pairing_state = "pairing_state"
    cache_key = DEVICE_PAIRING_TOKEN_KEY.format(pairing_token=context.pairing_token)
    cached_value = json.dumps(
        {
            "code": "ABC123",
            "uuid": context.device_id,
            "state": context.pairing_state,
            "token": context.pairing_token,
            "expiration": ONE_DAY,
        }
    )
    # The cache entry itself expires after a minute; the "expiration" field
    # inside the payload is a separate value.
    context.cache.set_with_expiration(
        key=cache_key, value=cached_value, expiration=ONE_MINUTE
    )


@when("a device requests a pairing code")
def get_device_pairing_code(context):
    """Request pairing data from the pairing code endpoint.

    A random state value is generated and saved on the context so that later
    steps can compare it with what the endpoint returns and caches.
    """
    context.state = str(uuid.uuid4())
    url_template = "/v1/device/code?state={state}&packaging=pantacor"
    context.response = context.client.get(url_template.format(state=context.state))


@when("the device requests to be activated")
def activate_device(context):
    """Post an activation request to finish registering the device.

    This call is for devices that are not managed by Pantacor.
    """
    request_body = json.dumps(
        {
            "token": context.pairing_token,
            "state": context.pairing_state,
            "platform": "test_platform",
            "core_version": "test_core_version",
            "enclosure_version": "test_enclosure_version",
        }
    )
    context.response = context.client.post(
        "/v1/device/activate", data=request_body, content_type="application/json"
    )


@when("Pantacor has not yet claimed the device")
def set_pantacor_not_claimed(context):
    """Flag on the context that Pantacor has not claimed the device.

    The flag is read when building the mocked Pantacor device response.
    """
    context.pantacor_claimed = False


@when("Pantacor has claimed the device")
def set_pantacor_claimed(context):
    """Flag on the context that Pantacor has claimed the device.

    The flag is read when building the mocked Pantacor device response.
    """
    # Previously this function reused the name of the "not yet claimed" step,
    # which shadowed that function at module level and contradicted what this
    # step does.  The step text, which is what behave matches on, is unchanged.
    context.pantacor_claimed = True


@when("a device requests to sync with Pantacor")
def activate_pantacor_device(context):
    """Post a Pantacor sync request to finish registering the device.

    This call is for devices that are managed by Pantacor.  ``requests.request``
    is patched for the duration of the call; its first invocation returns the
    mocked channel response and its second the mocked device response.
    """
    request_body = json.dumps(
        {
            "mycroft_device_id": context.device_login["uuid"],
            "pantacor_device_id": "test_pantacor_id",
        }
    )
    # Order matters: side_effect hands these out one per call, in sequence.
    mocked_responses = [
        _mock_get_channel_response(),
        _mock_get_device_response(context),
    ]
    with patch("requests.request", side_effect=mocked_responses):
        context.response = context.client.post(
            "/v1/device/pantacor",
            data=request_body,
            content_type="application/json",
            headers=context.request_header,
        )


def _mock_get_channel_response() -> MagicMock:
    """Build a stand-in for the Pantacor API channel endpoint response.

    Ideally, there would be a test device setup so we could test without mocking.
    Until then, here we are.

    Returns:
        A mock exposing only ``ok`` (True) and ``content`` (a JSON document,
        encoded to bytes, listing a single test channel).
    """
    channel = {"id": "test_channel_id", "name": "test_channel_name"}
    body = json.dumps({"items": [channel]})
    # spec restricts the mock to the two attributes set below.
    mocked_response = MagicMock(spec=["ok", "content"])
    mocked_response.content = body.encode()
    mocked_response.ok = True

    return mocked_response


def _mock_get_device_response(context) -> MagicMock:
    """Build a stand-in for the Pantacor API device endpoint response.

    Ideally, there would be a test device setup so we could test without mocking.
    Until then, here we are.

    Args:
        context: object whose ``pantacor_claimed`` attribute decides whether the
            claimed label in the response is 1 (truthy) or 0 (falsy).

    Returns:
        A mock exposing only ``ok`` (True) and ``content`` (a JSON document,
        encoded to bytes, describing a test device).
    """
    claimed_flag = 1 if context.pantacor_claimed else 0
    device = {
        "id": "test_device_id",
        "channel_id": "test_channel_id",
        "update_policy": "auto",
        "labels": [
            "device-meta/interfaces.wlan0.ipv4.0=192.168.1.2",
            f"device-meta/pantahub.claimed={claimed_flag}",
        ],
    }
    # spec restricts the mock to the two attributes set below.
    mocked_response = MagicMock(spec=["ok", "content"])
    mocked_response.content = json.dumps(device).encode()
    mocked_response.ok = True

    return mocked_response


@then("the pairing data is stored in Redis")
def check_cached_pairing_data(context):
    """Verify the pairing data cached under the returned pairing code."""
    returned_code = context.response.json["code"]
    cache_key = DEVICE_PAIRING_CODE_KEY.format(pairing_code=returned_code)
    cached = json.loads(context.cache.get(cache_key))
    # The entry is removed before any assertion runs -- presumably so that a
    # failed assertion does not leave it behind in the cache.
    context.cache.delete(cache_key)
    assert_that(cached, has_key("token"))
    assert_that(cached["code"], equal_to(returned_code))
    assert_that(cached["expiration"], equal_to(ONE_DAY))
    assert_that(cached["state"], equal_to(context.state))
    assert_that(cached["packaging_type"], equal_to("pantacor"))


@then("the pairing data is sent to the device")
def validate_pairing_code_response(context):
    """Verify the body of the pairing code endpoint's response."""
    payload = context.response.json
    for required_key in ("code", "token"):
        assert_that(payload, has_key(required_key))
    assert_that(payload["expiration"], equal_to(ONE_DAY))
    assert_that(payload["state"], equal_to(context.state))


@then("the activation data is sent to the device")
def validate_activation_response(context):
    """Verify the body of the activation endpoint's response."""
    payload = context.response.json
    assert_that(payload["uuid"], equal_to(context.device_id))
    for token_key in ("accessToken", "refreshToken"):
        assert_that(payload, has_key(token_key))
    assert_that(payload["expiration"], equal_to(ONE_DAY))


@then("the device attributes are stored in the database")
def validate_device_update(context):
    """Verify the device row holds the values sent in the activation request."""
    device = DeviceRepository(context.db).get_device_by_id(context.device_id)
    expected_attributes = (
        ("core_version", "test_core_version"),
        ("platform", "test_platform"),
        ("enclosure_version", "test_enclosure_version"),
    )
    for attribute_name, expected_value in expected_attributes:
        assert_that(getattr(device, attribute_name), equal_to(expected_value))


@then("the Pantacor device configuration is stored in the database")
def validate_pantacor_update(context):
    """Verify the device's stored Pantacor config matches the mocked API data."""
    device = DeviceRepository(context.db).get_device_by_id(context.device_id)
    pantacor_config = device.pantacor_config
    # Checked first so that a missing config fails here, not on attribute access.
    assert_that(pantacor_config, not_none())
    expected_attributes = (
        ("pantacor_id", "test_pantacor_id"),
        ("release_channel", "test_channel_name"),
        ("auto_update", True),
        ("ip_address", "192.168.1.2"),
    )
    for attribute_name, expected_value in expected_attributes:
        assert_that(getattr(pantacor_config, attribute_name), equal_to(expected_value))
    assert_that(pantacor_config.ssh_public_key, none())
